import tensorflow as tf
from tensorflow import keras
from ResNeXt import resnext
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# 报错解决：NotFoundError: No algorithm worked! when using Conv2D
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)


# ------------------------------------ #
# 预测参数设置
# ------------------------------------ #
im_height = 224  # 输入图像的高
im_width = 224   # 输入图像的高
# 分类名称
class_names = ['forbiden', 'warning', 'goahead', 'slow']
# 权重路径
weight_dir = 'save_weights/resnext.h5'

# ------------------------------------ #
# 单张图片预测
# ------------------------------------ #
# 是否只预测一张图
single_pic = False
# 图像所在文件夹的路径
single_filepath = 'D:/deeplearning/test/数据集/交通标志/new_data/test/禁令标志/'  
# 指定某张图片
picture = single_filepath + '010_0001.png'

# ------------------------------------ #
# 对测试集图片预测
# ------------------------------------ #
test_pack = True
# 验证集文件夹路径
test_filepath = 'D:/deeplearning/test/数据集/交通标志/new_data/test/'


#（1）载入模型
model = resnext(input_shape=[224,224,3], classes=4)  # 模型的输入shape和输出分类数
print('model is loaded')

#（2）载入权重.h文件
model.load_weights(weight_dir)
print('weights is loaded')

#（3）只对单张图像预测
if single_pic is True:
    
    # 加载图片
    img = Image.open(picture)
    # 改变图片size
    img = img.resize((im_height, im_width))
    # 展示图像
    plt.figure()
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
    
    # 图像像素值归一化处理
    img = np.array(img) / 255.0
    
    # 输入网络的要求，给图像增加一个batch维度, [h,w,c]==>[b,h,w,c]
    img = np.expand_dims(img, axis=0)

    # 预测图片，返回结果包含batch维度[b,n]
    result = model.predict(img)
    # 转换成一维，挤压掉batch维度
    result = np.squeeze(result)
    
    # 找到概率最大值对应的索引
    predict_class = np.argmax(result)
    
    # 打印预测类别及概率
    print('class:', class_names[predict_class], 
          'prob:', result[predict_class])
    
    plt.title(f'{class_names[predict_class]}')
    plt.show()

#（4）对测试集图像预测
if test_pack is True:
    
    # 载入测试集
    test_ds = keras.preprocessing.image_dataset_from_directory(
        directory = test_filepath, 
        label_mode = 'int',  # 不经过ont编码, 1、2、3、4、、、 
        image_size = (im_height, im_width),  # 测试集的图像resize
        batch_size = 32)  # 每批次32张图
    
    # 测试机预处理
    #（2）数据预处理
    def processing(image, label): 
        image = tf.cast(image, tf.float32) / 255.0  #[0,1]之间
        label = tf.cast(label, tf.int32)  # 修改数据类型
        return (image, label)
 
    test_ds = test_ds.map(processing) # 预处理


    test_true = []  # 存放真实值
    test_pred = []  # 存放预测值
    
    # 遍历测试集所有的batch
    for imgs, labels in test_ds:
        # 每次每次取出一个batch的一张图像和一个标签
        for img, label in zip(imgs, labels):
            
            # 网络输入的要求，给图像增加一个维度[h,w,c]==>[b,h,w,c]
            image_array = tf.expand_dims(img, axis=0)
            # 预测某一张图片，返回图片属于许多类别的概率
            prediction = model.predict(image_array)
            
            # 找到预测概率最大的索引对应的类别
            test_pred.append(class_names[np.argmax(prediction)])
            # label是真实标签索引
            test_true.append(class_names[label])
            
    # 展示结果
    print('真实值: ', test_true[:10])
    print('预测值: ', test_pred[:10])
    
    # 绘制混淆矩阵
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    import pandas as pd
    plt.rcParams['font.sans-serif'] = ['SimSun']  #宋体
    plt.rcParams['font.size'] = 15  #设置字体大小
    
    # 生成混淆矩阵
    conf_numpy = confusion_matrix(test_true, test_pred)
    # 转换成DataFrame表格类型，设置行列标签
    conf_df = pd.DataFrame(conf_numpy, index=class_names, columns=class_names)
    
    # 创建绘图区
    plt.figure(figsize=(8,7))
    
    # 生成热力图
    sns.heatmap(conf_df, annot=True, fmt="d", cmap="BuPu")
    
    # 设置标签
    plt.title('Confusion_Matrix')
    plt.xlabel('Predict')
    plt.ylabel('True')

